package org.zjvis.datascience.service;

import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.google.common.base.Joiner;

import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.Statement;
import java.sql.Types;
import java.util.*;
import java.util.regex.Pattern;

import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.zjvis.datascience.common.constant.Constant;
import org.zjvis.datascience.common.constant.SemanticConstant;
import org.zjvis.datascience.common.constant.SqlTemplate;
import org.zjvis.datascience.common.dto.DatasetDTO;
import org.zjvis.datascience.common.dto.RecommendDTO;
import org.zjvis.datascience.common.enums.PythonDateTypeFormatEnum;
import org.zjvis.datascience.common.enums.SemanticEnum;
import org.zjvis.datascience.common.enums.SemanticSubEnum;
import org.zjvis.datascience.common.etl.Sql;
import org.zjvis.datascience.common.sql.SqlHelper;
import org.zjvis.datascience.common.util.JwtUtil;
import org.zjvis.datascience.common.util.SemanticUtil;
import org.zjvis.datascience.common.util.SqlUtil;
import org.zjvis.datascience.common.util.db.JDBCUtil;
import org.zjvis.datascience.common.vo.DataPreviewVO;
import org.zjvis.datascience.common.vo.RecommendVO;
import org.zjvis.datascience.service.dataprovider.GPDataProvider;
import org.zjvis.datascience.service.dataset.DatasetService;
import org.zjvis.datascience.service.mapper.RecommendMapper;

/**
 * @description Recommend (recommendation) service: pre-aggregates a source table into a
 * result table and persists / reads the corresponding {@link RecommendDTO} records.
 * @date 2021-11-11
 */
@Service
public class RecommendService {

    private static final Logger logger = LoggerFactory.getLogger("RecommendService");

    /** Call template for the {@code recommendPretreat} database function (not referenced in this class). */
    private static final String SQL_TPL = "select * from \"%s\".\"recommendPretreat\"('%s')";

    /**
     * The aggregate {@code method} from the request is spliced into SQL as a function name,
     * so it must be a plain identifier.
     */
    private static final Pattern SQL_IDENTIFIER = Pattern.compile("[A-Za-z_][A-Za-z0-9_]*");

    /** Maximum number of date formats tried when binning a DATE column. */
    private static final int MAX_DATE_FORMAT_ATTEMPTS = 5;

    @Autowired
    DataPreviewService dataPreviewService;

    @Autowired
    GPDataProvider gpDataProvider;

    @Autowired
    DatasetService datasetService;

    @Autowired
    RecommendMapper recommendMapper;

    @Autowired
    SemanticService semanticService;

    /**
     * Updates a recommend record.
     *
     * @param dto record to update
     * @return the result reported by {@link RecommendMapper#update}
     */
    public boolean update(RecommendDTO dto) {
        return recommendMapper.update(dto);
    }

    /**
     * Inserts a recommend record.
     *
     * @param dto record to insert; its id is read back after the mapper call
     * @return the id present on {@code dto} after saving (presumably generated by the mapper)
     */
    public Long save(RecommendDTO dto) {
        recommendMapper.save(dto);
        return dto.getId();
    }

    /**
     * Loads a recommend record by id.
     *
     * @param id record id
     * @return the record as returned by the mapper (may be {@code null} if absent)
     */
    public RecommendDTO queryById(Long id) {
        return recommendMapper.queryById(id);
    }

    /**
     * Returns the result table names stored on a recommend record.
     *
     * @param id record id
     * @return the comma-separated {@code tableNames} split into a mutable list; empty when the
     *         record does not exist or has no table names
     */
    public List<String> queryTableNameById(Long id) {
        RecommendDTO dto = recommendMapper.queryById(id);
        if (dto == null) {
            logger.error("recommend record {} not found", id);
            return new ArrayList<>();
        }
        String tableNames = dto.getTableNames();
        if (StringUtils.isEmpty(tableNames)) {
            logger.error("tableNames is empty!!!!");
            return new ArrayList<>();
        }
        // copy so the caller always gets a mutable list, like the empty-result branches
        return new ArrayList<>(Arrays.asList(tableNames.split(",")));
    }

    /**
     * Resolves the physical table of a dataset from its {@code dataJson}.
     *
     * @param datasetId dataset id
     * @return {@code "schema.table"} (unquoted); the schema defaults to
     *         {@link SqlTemplate#SOURCE_SCHEMA} when {@code dataJson} has none
     */
    private String getTableNameById(Long datasetId) {
        DatasetDTO datasetDTO = datasetService.queryById(datasetId);
        String dataJson = datasetDTO.getDataJson();
        JSONObject jsonObject = JSONObject.parseObject(dataJson);
        String table = jsonObject.getString("table");
        String schema = jsonObject.getString("schema");
        if (StringUtils.isEmpty(schema)) {
            schema = SqlTemplate.SOURCE_SCHEMA;
        }
        return String.format("%s.%s", schema, table);
    }

    /**
     * Splits the value range of a numeric column into {@code binCount} equal-width bins.
     *
     * <p>The edges are spread over {@code binCount - 1} steps between min and max. The last bin
     * therefore starts at max and extends one step beyond it.
     *
     * @param tableName table to scan (spliced into SQL as-is)
     * @param col       column expression (spliced into SQL as-is; callers pass a formatted name)
     * @param binCount  number of bins, at least 2
     * @return array of {@code [left, right]} float pairs
     * @throws IllegalArgumentException if {@code binCount < 2} (the step would divide by zero)
     * @throws IllegalStateException    if the column has no non-null min/max
     */
    public JSONArray getBins(String tableName, String col, int binCount) {
        if (binCount < 2) {
            throw new IllegalArgumentException("binCount must be at least 2, got " + binCount);
        }
        String sql = String.format("select min(%s) as min, max(%s) as max from %s", col, col, tableName);
        JSONArray data = gpDataProvider.executeQuery(sql);
        JSONObject meta = (data == null || data.isEmpty()) ? null : data.getJSONObject(0);
        Float minValue = meta == null ? null : meta.getFloat("min");
        Float maxValue = meta == null ? null : meta.getFloat("max");
        if (minValue == null || maxValue == null) {
            throw new IllegalStateException(
                "cannot compute bins: no non-null min/max for " + col + " in " + tableName);
        }
        float min = minValue;
        float max = maxValue;
        float step = (max - min) / (binCount - 1);
        JSONArray bins = new JSONArray();
        for (int i = 0; i < binCount; i++) {
            float left = min + step * i;
            float right = min + step * (i + 1);
            bins.add(new JSONArray(Arrays.asList(left, right)));
        }
        return bins;
    }

    /**
     * Builds an aggregated result table for the requested fields and saves a recommend record
     * pointing at it.
     *
     * <p>{@code param} is read for {@code tableName}, {@code method}, {@code fields} and an
     * optional {@code targetField}. Each field has {@code name}, {@code type}, an optional
     * {@code semantic} and an optional {@code binCount}.
     *
     * <p>Failures are logged, not rethrown. In that case the returned VO may be only partially
     * filled (for example, without an id).
     *
     * @param param request parameters
     * @return the populated {@link RecommendVO}
     */
    public RecommendVO preCalculate(JSONObject param) {
        RecommendVO recommendVO = new RecommendVO();
        Connection conn = null;
        try {
            conn = gpDataProvider.getConn(1L);
            // the statement used to be left open; close it before the connection is released
            try (Statement st = conn.createStatement()) {
                createRecommendTable(st, param, recommendVO);
            }
        } catch (Exception e) {
            logger.error("preCalculate failed: {}", e.getMessage(), e);
        } finally {
            JDBCUtil.close(conn, null, null);
        }
        return recommendVO;
    }

    /**
     * Does the work of {@link #preCalculate}: probes the columns, builds and executes the DDL,
     * then fills and saves {@code recommendVO}.
     *
     * <p>Each field with {@code binCount} set and more distinct values than that is binned:
     * <ul>
     *   <li>DATE fields are formatted with the first {@link PythonDateTypeFormatEnum} (of at most
     *       five tried) whose distinct count fits.</li>
     *   <li>Province fields are mapped to country.</li>
     *   <li>City fields are mapped to province, and further to country when still too many.</li>
     *   <li>Other fields are bucketed via {@link #getBins} and {@code pipeline."getBinFloat"}.</li>
     * </ul>
     *
     * <p>NOTE(review): {@code tableName} and the field names come from the request and are
     * spliced into SQL; only {@code method} is validated here.
     */
    private void createRecommendTable(Statement st, JSONObject param, RecommendVO recommendVO) throws Exception {
        JSONObject outputJson = new JSONObject();
        outputJson.put("params", param);
        outputJson.put("binCalculate", false);
        recommendVO.setUserId(JwtUtil.getCurrentUserId());
        String tableName = param.getString("tableName");
        String method = param.getString("method");
        JSONArray fields = param.getJSONArray("fields");

        String tmpTable = newTempTableName();
        List<String> outputCols = new ArrayList<>();
        List<String> columnBody = new ArrayList<>();
        String joinBody = "";
        List<String> binCols = new ArrayList<>();
        List<Integer> types = new ArrayList<>();
        JSONArray binRanges = new JSONArray();
        // DDL accumulated across all fields and executed as one batch at the end
        String sql = "";
        for (Object obj : fields) {
            JSONObject field = (JSONObject) obj;
            String col = field.getString("name");
            String key = SqlUtil.formatPGSqlColName(col);
            String colAfterOri = col + "_after";
            String colAfter = SqlUtil.formatPGSqlColName(colAfterOri);

            int type = field.getInteger("type");
            type = SqlHelper.mergeSqlType(type);
            types.add(type);

            String semantic = field.getString("semantic");

            // separate variable: the probe must not overwrite DDL already queued in `sql`
            String countSql = String.format("select count(distinct(%s)) as count from %s;", key, tableName);
            int distinct = gpDataProvider.executeQuery(st, countSql).getJSONObject(0).getInteger("count");
            Integer binCount = field.getInteger("binCount");
            if (binCount != null && distinct > binCount) {
                int count = binCount;
                if (type == Types.DATE) {
                    String columnBin = "";
                    int attempts = 0;
                    for (PythonDateTypeFormatEnum item : PythonDateTypeFormatEnum.values()) {
                        if (attempts == MAX_DATE_FORMAT_ATTEMPTS) {
                            break;
                        }
                        attempts += 1;
                        columnBin = String.format("pipeline.\"sys_func_format_time\"(%s::varchar, '%s') as %s", key, item.getVal(), colAfter);
                        String formatCountSql = String.format("select count(*) from (select distinct pipeline.\"sys_func_format_time\"(%s::varchar, '%s')\n"
                            + "from %s) sub;", key, item.getVal(), tableName);
                        distinct = gpDataProvider.executeQuery(st, formatCountSql).getJSONObject(0).getInteger("count");
                        if (distinct <= count) {
                            break;
                        }
                    }
                    columnBody.add(columnBin);
                } else if (null != semantic && semantic.equals(SemanticSubEnum.PROVINCE.getVal())) {
                    List<String> sqlItems = SemanticUtil
                        .semanticItems(key, SemanticConstant.PROVINCE_COL, "a.", "b.");
                    sql += String.format("create table %s as select distinct(_record_id_), b.%s as %s, a.country as %s"
                            + " from dataset._province_mapper_ a join %s b on %s;",
                        tmpTable, key, key, colAfter, tableName, Joiner.on(" or ").join(sqlItems));
                    columnBody.add(String.format("b.%s", colAfter));
                    joinBody = String.format("a join %s b on a.%s = b.%s", tmpTable, key, key);
                } else if (null != semantic && semantic.equals(SemanticSubEnum.CITY.getVal())) {
                    List<String> sqlItems = SemanticUtil.semanticItems(key, SemanticConstant.CITY_COL, "a.", "b.");
                    // creates tmpTable (city -> province) immediately, then counts the mapped provinces
                    String citySql = String.format("create table %s as select b.%s as %s, a.province as %s"
                            + " from dataset._city_mapper_ a join %s b on %s;",
                        tmpTable, key, key, colAfter, tableName, Joiner.on(" or ").join(sqlItems));
                    // count the mapped column (province), not the city key, to decide on escalation
                    citySql += String.format("select count(distinct(%s)) as count from %s;", colAfter, tmpTable);
                    distinct = gpDataProvider.executeQuery(st, citySql).getJSONObject(0).getInteger("count");
                    // NOTE(review): escalates when distinct == binCount too, unlike the DATE branch; confirm
                    if (distinct >= count) {
                        sqlItems.clear();
                        for (String item : SemanticConstant.PROVINCE_COL) {
                            sqlItems.add(String.format("a.%s = b.%s", item, colAfter));
                        }
                        String countryTable = newTempTableName();
                        sql += String.format("create table %s as select b.%s as %s, a.country as %s"
                                + " from dataset._province_mapper_ a join %s b on %s;",
                            countryTable, key, key, colAfter, tmpTable, Joiner.on(" or ").join(sqlItems));
                        sql += String.format("drop table if exists %s;", tmpTable);
                        tmpTable = countryTable;
                    }
                    columnBody.add(String.format("b.%s", colAfter));
                    joinBody = String.format("a join %s b on a.%s = b.%s", tmpTable, key, key);
                } else {
                    JSONArray bins = getBins(tableName, key, count);
                    // left edge of every bin, then the right edge of the last bin
                    JSONArray binRange = new JSONArray();
                    for (int i = 0; i < bins.size(); i++) {
                        binRange.add(bins.getJSONArray(i).get(0));
                    }
                    binRange.add(bins.getJSONArray(bins.size() - 1).get(1));
                    binRanges.add(binRange);
                    binCols.add(colAfter);
                    columnBody.add(String.format("pipeline.\"getBinFloat\"(%s, '%s') as %s", key,
                        bins.toJSONString(), colAfter));
                    outputJson.put("binCalculate", true);
                }
            } else {
                columnBody.add(key);
                colAfterOri = col;
            }
            outputCols.add(colAfterOri);
        }

        String otherSql = String.format("%s group by (%s);", joinBody, Joiner.on(",").join(SqlUtil.formatPGSqlCols(outputCols)));
        String outputTable = newTempTableName();
        if (param.getString("targetField") == null) {
            sql += String.format("create table %s as select row_number() over() as _record_id_, %s, count(*) as count from %s %s",
                outputTable, Joiner.on(",").join(columnBody), tableName, otherSql);
            outputCols.add("count");
        } else {
            if (method == null || !SQL_IDENTIFIER.matcher(method).matches()) {
                throw new IllegalArgumentException("invalid aggregate method: " + method);
            }
            JSONObject targetField = param.getJSONObject("targetField");
            String targetCol = targetField.getString("name");
            String outputCol = String.format("%s_%s", method, targetCol);
            String methodSql = String.format("%s(%s) as %s", method, SqlUtil.formatPGSqlColName(targetCol), SqlUtil.formatPGSqlColName(outputCol));
            sql += String.format("create table %s as select row_number() over() as _record_id_, %s, %s from %s %s",
                outputTable, Joiner.on(",").join(columnBody), methodSql, tableName, otherSql);
            outputCols.add(outputCol);
        }
        outputJson.put("outputCols", outputCols);

        sql += buildBinFillSql(outputTable, binCols, binRanges);
        sql += buildColumnTypeSql(outputTable, types, outputCols);
        sql += String.format("alter table %s drop _record_id_;", outputTable);
        sql += String.format("drop table if exists %s;", tmpTable);
        gpDataProvider.executeSql(st, sql);

        recommendVO.setData(outputJson);
        recommendVO.setTableNames(outputTable);
        Long id = this.save(recommendVO.toTask());
        recommendVO.setId(id);
    }

    /** Generates a fresh {@code pipeline.solid_rec_<millis><random>} table name. */
    private static String newTempTableName() {
        return "pipeline.solid_rec_" + System.currentTimeMillis() + (long) (Math.random() * Constant.RANDOM);
    }

    /**
     * Builds INSERT statements adding a row for each bin-edge value (or pair of values) not
     * already present in the output table. Only one or two binned columns are handled; for any
     * other count it returns an empty string.
     */
    private static String buildBinFillSql(String outputTable, List<String> binCols, JSONArray binRanges) {
        StringBuilder fill = new StringBuilder();
        if (binCols.size() == 1) {
            String binCol = binCols.get(0);
            JSONArray range = binRanges.getJSONArray(0);
            for (int k = 0; k < range.size(); k++) {
                float value = range.getFloat(k);
                fill.append(String.format(
                    "insert into %s(%s) select %s where not exists(select 1 from %s where %s=%s);",
                    outputTable, binCol, value, outputTable, binCol, value));
            }
        } else if (binCols.size() == 2) {
            String col1 = binCols.get(0);
            String col2 = binCols.get(1);
            JSONArray range1 = binRanges.getJSONArray(0);
            JSONArray range2 = binRanges.getJSONArray(1);
            for (int i = 0; i < range1.size(); i++) {
                float value1 = range1.getFloat(i);
                for (int j = 0; j < range2.size(); j++) {
                    float value2 = range2.getFloat(j);
                    fill.append(String.format(
                        "insert into %s(%s, %s) select %s, %s where not exists(select 1 from %s where %s=%s and %s=%s);",
                        outputTable, col1, col2, value1, value2, outputTable, col1, value1, col2, value2));
                }
            }
        }
        return fill.toString();
    }

    /**
     * Builds ALTERs that retype output columns: merged INTEGER columns become {@code bigint}
     * and DECIMAL columns become {@code decimal}. {@code types.get(i)} belongs to
     * {@code outputCols.get(i)}.
     */
    private static String buildColumnTypeSql(String outputTable, List<Integer> types, List<String> outputCols) {
        StringBuilder alters = new StringBuilder();
        for (int i = 0; i < types.size(); i++) {
            int type = types.get(i);
            String colAfter = SqlUtil.formatPGSqlColName(outputCols.get(i));
            if (type == Types.INTEGER) {
                alters.append(String.format("alter table %s alter column %s type bigint;", outputTable, colAfter));
            } else if (type == Types.DECIMAL) {
                alters.append(String.format("alter table %s alter column %s type decimal;", outputTable, colAfter));
            }
        }
        return alters.toString();
    }

    /**
     * Builds a data preview for every table listed in {@code vo.getTableNames()} (comma-separated).
     *
     * @param vo recommend view object
     * @return one preview per table; empty when {@code vo} is null or has no table names
     */
    public List<DataPreviewVO> viewData(RecommendVO vo) {
        List<DataPreviewVO> dataPreviewVOS = new ArrayList<>();
        if (vo == null || StringUtils.isEmpty(vo.getTableNames())) {
            return dataPreviewVOS;
        }
        for (String tableName : vo.getTableNames().split(",")) {
            dataPreviewVOS.add(dataPreviewService.buildDataPreviewVO(tableName, 0L, 0L, 0L));
        }
        return dataPreviewVOS;
    }

    /**
     * Loads a recommend record and previews its tables.
     *
     * @param id record id
     * @return previews of the record's tables; empty when the record does not exist
     */
    public List<DataPreviewVO> viewDataById(Long id) {
        RecommendDTO dto = this.queryById(id);
        if (dto == null) {
            logger.error("recommend record {} not found", id);
            return new ArrayList<>();
        }
        return viewData(dto.view());
    }

    /**
     * Queries records through {@link RecommendMapper#queryByModify} using a
     * {@code startDate}/{@code endDate} parameter map (presumably a modification-time window).
     */
    public List<RecommendDTO> queryByModify(String startDate, String endDate) {
        Map<String, String> map = new HashMap<>();
        map.put("startDate", startDate);
        map.put("endDate", endDate);
        return recommendMapper.queryByModify(map);
    }
}
